{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global RNG so Restart & Run All reproduces the random tensors\n",
    "# below (and anything drawn from the global RNG afterwards).\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cpu')\n",
    "# N rows of data; D_in / H / D_out are the input, hidden and output widths.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "# NOTE(review): the training cell passes lr=1e-4 to SGD and does not read\n",
    "# this value -- confirm which of the two is intended.\n",
    "learning_rate = 1e-6\n",
    "\n",
    "# Random inputs and regression targets, both allocated on `device`.\n",
    "x = torch.randn(N, D_in, device=device)\n",
    "y = torch.randn(N, D_out, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 668.4280395507812\n",
      "1 619.3014526367188\n",
      "2 576.7938232421875\n",
      "3 538.9149780273438\n",
      "4 505.1468200683594\n",
      "5 474.6072998046875\n",
      "6 446.716064453125\n",
      "7 420.95367431640625\n",
      "8 397.27410888671875\n",
      "9 375.32220458984375\n",
      "10 354.959228515625\n",
      "11 336.01031494140625\n",
      "12 318.1667175292969\n",
      "13 301.3739318847656\n",
      "14 285.44549560546875\n",
      "15 270.3079833984375\n",
      "16 255.87188720703125\n",
      "17 242.06858825683594\n",
      "18 228.9234161376953\n",
      "19 216.41441345214844\n",
      "20 204.55685424804688\n",
      "21 193.25192260742188\n",
      "22 182.4688262939453\n",
      "23 172.12826538085938\n",
      "24 162.2840576171875\n",
      "25 152.9571533203125\n",
      "26 144.09112548828125\n",
      "27 135.6737823486328\n",
      "28 127.70211791992188\n",
      "29 120.16284942626953\n",
      "30 113.06999969482422\n",
      "31 106.37980651855469\n",
      "32 100.05526733398438\n",
      "33 94.0878677368164\n",
      "34 88.4574966430664\n",
      "35 83.14505004882812\n",
      "36 78.1514663696289\n",
      "37 73.44068145751953\n",
      "38 69.0069808959961\n",
      "39 64.8420181274414\n",
      "40 60.924922943115234\n",
      "41 57.24131393432617\n",
      "42 53.78363037109375\n",
      "43 50.530086517333984\n",
      "44 47.4709587097168\n",
      "45 44.59964370727539\n",
      "46 41.905784606933594\n",
      "47 39.37813186645508\n",
      "48 37.00717544555664\n",
      "49 34.780757904052734\n",
      "50 32.692893981933594\n",
      "51 30.736156463623047\n",
      "52 28.90078353881836\n",
      "53 27.179122924804688\n",
      "54 25.557729721069336\n",
      "55 24.03583526611328\n",
      "56 22.60902976989746\n",
      "57 21.269756317138672\n",
      "58 20.013269424438477\n",
      "59 18.833091735839844\n",
      "60 17.72444725036621\n",
      "61 16.685829162597656\n",
      "62 15.710694313049316\n",
      "63 14.796137809753418\n",
      "64 13.936905860900879\n",
      "65 13.130857467651367\n",
      "66 12.373797416687012\n",
      "67 11.663031578063965\n",
      "68 10.995936393737793\n",
      "69 10.369232177734375\n",
      "70 9.78137493133545\n",
      "71 9.228534698486328\n",
      "72 8.710663795471191\n",
      "73 8.223787307739258\n",
      "74 7.767024040222168\n",
      "75 7.337453365325928\n",
      "76 6.933601379394531\n",
      "77 6.553715229034424\n",
      "78 6.1958208084106445\n",
      "79 5.859481334686279\n",
      "80 5.542868614196777\n",
      "81 5.244792938232422\n",
      "82 4.963951587677002\n",
      "83 4.699268817901611\n",
      "84 4.449596405029297\n",
      "85 4.2145562171936035\n",
      "86 3.9926838874816895\n",
      "87 3.7836391925811768\n",
      "88 3.5864405632019043\n",
      "89 3.4002845287323\n",
      "90 3.2245118618011475\n",
      "91 3.05846905708313\n",
      "92 2.9012093544006348\n",
      "93 2.7528982162475586\n",
      "94 2.6127452850341797\n",
      "95 2.4801535606384277\n",
      "96 2.3548946380615234\n",
      "97 2.2367923259735107\n",
      "98 2.125178337097168\n",
      "99 2.019498109817505\n",
      "100 1.9196205139160156\n",
      "101 1.8249597549438477\n",
      "102 1.7352607250213623\n",
      "103 1.650418996810913\n",
      "104 1.5700196027755737\n",
      "105 1.4937825202941895\n",
      "106 1.421539068222046\n",
      "107 1.3531444072723389\n",
      "108 1.2882407903671265\n",
      "109 1.2267110347747803\n",
      "110 1.168365478515625\n",
      "111 1.1129708290100098\n",
      "112 1.06037437915802\n",
      "113 1.0103672742843628\n",
      "114 0.9630045890808105\n",
      "115 0.9180276393890381\n",
      "116 0.8753237724304199\n",
      "117 0.8347600698471069\n",
      "118 0.7963154911994934\n",
      "119 0.7597888112068176\n",
      "120 0.7251119017601013\n",
      "121 0.6921564936637878\n",
      "122 0.6608198285102844\n",
      "123 0.6310104131698608\n",
      "124 0.6026656031608582\n",
      "125 0.5756837129592896\n",
      "126 0.5499692559242249\n",
      "127 0.5254989266395569\n",
      "128 0.5022514462471008\n",
      "129 0.48006471991539\n",
      "130 0.4589249789714813\n",
      "131 0.43878281116485596\n",
      "132 0.4195723235607147\n",
      "133 0.40128132700920105\n",
      "134 0.3838435709476471\n",
      "135 0.3672121465206146\n",
      "136 0.3513574004173279\n",
      "137 0.33622273802757263\n",
      "138 0.3217832148075104\n",
      "139 0.3079968988895416\n",
      "140 0.2948598861694336\n",
      "141 0.28231483697891235\n",
      "142 0.2703128755092621\n",
      "143 0.2588854432106018\n",
      "144 0.24796921014785767\n",
      "145 0.23754766583442688\n",
      "146 0.22759704291820526\n",
      "147 0.2180878221988678\n",
      "148 0.20900553464889526\n",
      "149 0.20032700896263123\n",
      "150 0.1920250505208969\n",
      "151 0.18409015238285065\n",
      "152 0.17651577293872833\n",
      "153 0.16927050054073334\n",
      "154 0.16234105825424194\n",
      "155 0.1557142585515976\n",
      "156 0.14937591552734375\n",
      "157 0.14331094920635223\n",
      "158 0.1375124454498291\n",
      "159 0.1319582462310791\n",
      "160 0.12664245069026947\n",
      "161 0.12155882269144058\n",
      "162 0.11669367551803589\n",
      "163 0.11203818768262863\n",
      "164 0.10757976770401001\n",
      "165 0.1033095046877861\n",
      "166 0.09921789169311523\n",
      "167 0.09530003368854523\n",
      "168 0.09154531359672546\n",
      "169 0.08794800192117691\n",
      "170 0.08451281487941742\n",
      "171 0.08122212439775467\n",
      "172 0.07806745916604996\n",
      "173 0.07505367696285248\n",
      "174 0.07216597348451614\n",
      "175 0.06939806789159775\n",
      "176 0.0667470172047615\n",
      "177 0.06420008093118668\n",
      "178 0.061755359172821045\n",
      "179 0.05941098555922508\n",
      "180 0.05716121196746826\n",
      "181 0.05500328168272972\n",
      "182 0.05293368920683861\n",
      "183 0.0509478934109211\n",
      "184 0.0490410178899765\n",
      "185 0.04720799997448921\n",
      "186 0.04544854909181595\n",
      "187 0.04376072809100151\n",
      "188 0.04213841259479523\n",
      "189 0.04058206453919411\n",
      "190 0.03908592462539673\n",
      "191 0.03764795884490013\n",
      "192 0.03626532480120659\n",
      "193 0.03493737056851387\n",
      "194 0.033663060516119\n",
      "195 0.032437656074762344\n",
      "196 0.031259458512067795\n",
      "197 0.030127018690109253\n",
      "198 0.029037170112133026\n",
      "199 0.027989309281110764\n",
      "200 0.02698296122252941\n",
      "201 0.02601522207260132\n",
      "202 0.02508455514907837\n",
      "203 0.024188820272684097\n",
      "204 0.02332720346748829\n",
      "205 0.022497866302728653\n",
      "206 0.021699873730540276\n",
      "207 0.020932672545313835\n",
      "208 0.020194221287965775\n",
      "209 0.01948348991572857\n",
      "210 0.018798673525452614\n",
      "211 0.018139351159334183\n",
      "212 0.01750517450273037\n",
      "213 0.016894254833459854\n",
      "214 0.016306255012750626\n",
      "215 0.01574067585170269\n",
      "216 0.015196195803582668\n",
      "217 0.014671390876173973\n",
      "218 0.014165820553898811\n",
      "219 0.013678548857569695\n",
      "220 0.013208937831223011\n",
      "221 0.012756681069731712\n",
      "222 0.012320631183683872\n",
      "223 0.011901373974978924\n",
      "224 0.011496649123728275\n",
      "225 0.011106884106993675\n",
      "226 0.010730779729783535\n",
      "227 0.010368099436163902\n",
      "228 0.010018656961619854\n",
      "229 0.009681539610028267\n",
      "230 0.009356922470033169\n",
      "231 0.00904351007193327\n",
      "232 0.008741225115954876\n",
      "233 0.008449605666100979\n",
      "234 0.008168323896825314\n",
      "235 0.007897168397903442\n",
      "236 0.007635283749550581\n",
      "237 0.007382792886346579\n",
      "238 0.007139251567423344\n",
      "239 0.006904060021042824\n",
      "240 0.006676878780126572\n",
      "241 0.0064580379985272884\n",
      "242 0.006247041281312704\n",
      "243 0.006043288856744766\n",
      "244 0.005846584681421518\n",
      "245 0.005656846333295107\n",
      "246 0.005473543889820576\n",
      "247 0.0052964393980801105\n",
      "248 0.005125406663864851\n",
      "249 0.0049601392820477486\n",
      "250 0.004800703842192888\n",
      "251 0.0046466016210615635\n",
      "252 0.004497676622122526\n",
      "253 0.004353924188762903\n",
      "254 0.004215024411678314\n",
      "255 0.004080716986209154\n",
      "256 0.0039510722272098064\n",
      "257 0.0038256857078522444\n",
      "258 0.003704528324306011\n",
      "259 0.0035873299930244684\n",
      "260 0.003474246943369508\n",
      "261 0.0033648228272795677\n",
      "262 0.0032589694019407034\n",
      "263 0.0031566026154905558\n",
      "264 0.0030576346907764673\n",
      "265 0.002962043508887291\n",
      "266 0.0028695915825664997\n",
      "267 0.002780200680717826\n",
      "268 0.002693787682801485\n",
      "269 0.0026101593393832445\n",
      "270 0.0025292756035923958\n",
      "271 0.002450983040034771\n",
      "272 0.002375258831307292\n",
      "273 0.0023019565269351006\n",
      "274 0.0022311010397970676\n",
      "275 0.0021625033114105463\n",
      "276 0.002096213400363922\n",
      "277 0.002031973795965314\n",
      "278 0.0019698659889400005\n",
      "279 0.0019097381737083197\n",
      "280 0.0018515075789764524\n",
      "281 0.0017952037742361426\n",
      "282 0.001740649458952248\n",
      "283 0.001687859185039997\n",
      "284 0.0016367597272619605\n",
      "285 0.0015872701769694686\n",
      "286 0.0015393750509247184\n",
      "287 0.0014929687604308128\n",
      "288 0.0014480563113465905\n",
      "289 0.0014045407297089696\n",
      "290 0.0013623997801914811\n",
      "291 0.0013216090155765414\n",
      "292 0.0012820754200220108\n",
      "293 0.0012438116827979684\n",
      "294 0.001206719083711505\n",
      "295 0.0011708020465448499\n",
      "296 0.0011360307689756155\n",
      "297 0.0011023295810446143\n",
      "298 0.0010696646058931947\n",
      "299 0.0010380144231021404\n",
      "300 0.0010073393350467086\n",
      "301 0.0009776894003152847\n",
      "302 0.0009489087387919426\n",
      "303 0.0009210064308717847\n",
      "304 0.0008939398103393614\n",
      "305 0.0008677291334606707\n",
      "306 0.0008423201506957412\n",
      "307 0.0008176830597221851\n",
      "308 0.0007937990594655275\n",
      "309 0.000770646205637604\n",
      "310 0.0007482254295609891\n",
      "311 0.000726468802895397\n",
      "312 0.0007053824374452233\n",
      "313 0.0006849327473901212\n",
      "314 0.0006650997092947364\n",
      "315 0.0006458699353970587\n",
      "316 0.0006272243335843086\n",
      "317 0.0006091536488384008\n",
      "318 0.0005915929796174169\n",
      "319 0.0005745812668465078\n",
      "320 0.0005580888246186078\n",
      "321 0.0005420905654318631\n",
      "322 0.0005265663494355977\n",
      "323 0.0005114917294122279\n",
      "324 0.0004968929570168257\n",
      "325 0.0004827164229936898\n",
      "326 0.0004689573834184557\n",
      "327 0.00045560969738289714\n",
      "328 0.0004426606756169349\n",
      "329 0.00043010988156311214\n",
      "330 0.00041792294359765947\n",
      "331 0.00040608839481137693\n",
      "332 0.00039461645064875484\n",
      "333 0.0003834891540464014\n",
      "334 0.0003726732393261045\n",
      "335 0.0003621804353315383\n",
      "336 0.0003519785823300481\n",
      "337 0.00034208313445560634\n",
      "338 0.0003324836725369096\n",
      "339 0.0003231620357837528\n",
      "340 0.000314119242830202\n",
      "341 0.0003053498512599617\n",
      "342 0.00029681934393011034\n",
      "343 0.0002885369467549026\n",
      "344 0.00028050635592080653\n",
      "345 0.0002726892998907715\n",
      "346 0.0002651024260558188\n",
      "347 0.00025774171808734536\n",
      "348 0.00025059128529392183\n",
      "349 0.00024366166326217353\n",
      "350 0.00023692226386629045\n",
      "351 0.00023036645143292844\n",
      "352 0.00022401117894332856\n",
      "353 0.00021783517149742693\n",
      "354 0.00021182910131756216\n",
      "355 0.0002060057013295591\n",
      "356 0.0002003381960093975\n",
      "357 0.00019484169024508446\n",
      "358 0.00018949991499539465\n",
      "359 0.00018430834461469203\n",
      "360 0.00017926192958839238\n",
      "361 0.00017436128109693527\n",
      "362 0.00016960060747805983\n",
      "363 0.00016496855823788792\n",
      "364 0.000160473573487252\n",
      "365 0.00015611042908858508\n",
      "366 0.00015186684322543442\n",
      "367 0.0001477432087995112\n",
      "368 0.00014373086742125452\n",
      "369 0.00013983171083964407\n",
      "370 0.0001360469905193895\n",
      "371 0.00013236647646408528\n",
      "372 0.00012879098358098418\n",
      "373 0.00012530872481875122\n",
      "374 0.00012192891881568357\n",
      "375 0.00011864048428833485\n",
      "376 0.00011545195593498647\n",
      "377 0.00011234574049012735\n",
      "378 0.00010932660370599478\n",
      "379 0.00010638883395586163\n",
      "380 0.0001035374152706936\n",
      "381 0.00010075899626826867\n",
      "382 9.80704280664213e-05\n",
      "383 9.544336353428662e-05\n",
      "384 9.28973822738044e-05\n",
      "385 9.041393059305847e-05\n",
      "386 8.800379873719066e-05\n",
      "387 8.565903408452868e-05\n",
      "388 8.337706094607711e-05\n",
      "389 8.115889795590192e-05\n",
      "390 7.900226046331227e-05\n",
      "391 7.690395432291552e-05\n",
      "392 7.486288086511195e-05\n",
      "393 7.287965127034113e-05\n",
      "394 7.095025648595765e-05\n",
      "395 6.907253555255011e-05\n",
      "396 6.724859122186899e-05\n",
      "397 6.547294469783083e-05\n",
      "398 6.374534132191911e-05\n",
      "399 6.206848775036633e-05\n",
      "400 6.04308042966295e-05\n",
      "401 5.883846097276546e-05\n",
      "402 5.728831820306368e-05\n",
      "403 5.57846506126225e-05\n",
      "404 5.431690806290135e-05\n",
      "405 5.289401815389283e-05\n",
      "406 5.150696961209178e-05\n",
      "407 5.015786155126989e-05\n",
      "408 4.88449186377693e-05\n",
      "409 4.756770431413315e-05\n",
      "410 4.6322267735376954e-05\n",
      "411 4.5114073145668954e-05\n",
      "412 4.3935335270361975e-05\n",
      "413 4.279126733308658e-05\n",
      "414 4.1677598346723244e-05\n",
      "415 4.059138154843822e-05\n",
      "416 3.95344068238046e-05\n",
      "417 3.8505149859702215e-05\n",
      "418 3.750806354219094e-05\n",
      "419 3.653280145954341e-05\n",
      "420 3.558673779480159e-05\n",
      "421 3.4662858524825424e-05\n",
      "422 3.376556560397148e-05\n",
      "423 3.2891832233872265e-05\n",
      "424 3.204200402251445e-05\n",
      "425 3.121293048025109e-05\n",
      "426 3.040721276192926e-05\n",
      "427 2.962217877211515e-05\n",
      "428 2.885920366679784e-05\n",
      "429 2.8112779546063393e-05\n",
      "430 2.7388592570787296e-05\n",
      "431 2.6686288038035855e-05\n",
      "432 2.5999363060691394e-05\n",
      "433 2.5330125936307013e-05\n",
      "434 2.467943704687059e-05\n",
      "435 2.4044325982686132e-05\n",
      "436 2.3429380235029384e-05\n",
      "437 2.2828267901786603e-05\n",
      "438 2.224343188572675e-05\n",
      "439 2.1673749870387837e-05\n",
      "440 2.111853245878592e-05\n",
      "441 2.0578618205036037e-05\n",
      "442 2.0053332264069468e-05\n",
      "443 1.954195431608241e-05\n",
      "444 1.9043218344449997e-05\n",
      "445 1.8557506336946972e-05\n",
      "446 1.8083488612319343e-05\n",
      "447 1.7623133317101747e-05\n",
      "448 1.717328450467903e-05\n",
      "449 1.673746737651527e-05\n",
      "450 1.6310717910528183e-05\n",
      "451 1.5896104741841555e-05\n",
      "452 1.5493104001507163e-05\n",
      "453 1.510041875008028e-05\n",
      "454 1.4716693840455264e-05\n",
      "455 1.4343011571327224e-05\n",
      "456 1.3979440154798795e-05\n",
      "457 1.362533112114761e-05\n",
      "458 1.3280245184432715e-05\n",
      "459 1.2944614354637451e-05\n",
      "460 1.2616713320312556e-05\n",
      "461 1.2298514775466174e-05\n",
      "462 1.1987803191004787e-05\n",
      "463 1.168484322988661e-05\n",
      "464 1.1389593964850064e-05\n",
      "465 1.1102847565780394e-05\n",
      "466 1.0823197044373956e-05\n",
      "467 1.0550974366196897e-05\n",
      "468 1.0284874406352174e-05\n",
      "469 1.0025214578490704e-05\n",
      "470 9.77350828179624e-06\n",
      "471 9.528063856123481e-06\n",
      "472 9.28873305383604e-06\n",
      "473 9.055346708919387e-06\n",
      "474 8.828628779156134e-06\n",
      "475 8.60660475154873e-06\n",
      "476 8.390418770432007e-06\n",
      "477 8.179761607607361e-06\n",
      "478 7.975531843840145e-06\n",
      "479 7.775297490297817e-06\n",
      "480 7.58039368520258e-06\n",
      "481 7.391093731712317e-06\n",
      "482 7.206848749774508e-06\n",
      "483 7.026189450698439e-06\n",
      "484 6.850876161479391e-06\n",
      "485 6.679532361886231e-06\n",
      "486 6.512843356176745e-06\n",
      "487 6.3503302953904495e-06\n",
      "488 6.191920874698553e-06\n",
      "489 6.037036200723378e-06\n",
      "490 5.887446150154574e-06\n",
      "491 5.7401639423915185e-06\n",
      "492 5.597332346951589e-06\n",
      "493 5.457698989630444e-06\n",
      "494 5.322172455635155e-06\n",
      "495 5.190152933209902e-06\n",
      "496 5.061883712187409e-06\n",
      "497 4.935528522764798e-06\n",
      "498 4.812867700820789e-06\n",
      "499 4.6928853407735005e-06\n"
     ]
    }
   ],
   "source": [
    "class TwoLayerNet(torch.nn.Module):\n",
    "    \"\"\"Two linear layers with a zero-clamp (ReLU-style) non-linearity between them.\"\"\"\n",
    "\n",
    "    def __init__(self, D_in, H, D_out):\n",
    "        \"\"\"Create the two linear layers.\n",
    "\n",
    "        Args:\n",
    "            D_in: number of input features of the first layer.\n",
    "            H: number of hidden units (output of fc1, input of fc2).\n",
    "            D_out: number of output features of the second layer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.fc1 = torch.nn.Linear(D_in, H)\n",
    "        self.fc2 = torch.nn.Linear(H, D_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``fc2`` applied to ``fc1(x)`` clamped below at zero.\"\"\"\n",
    "        hidden = torch.clamp(self.fc1(x), min=0)\n",
    "        return self.fc2(hidden)\n",
    "    \n",
    "# Build the network on the same device as the data so a non-CPU `device`\n",
    "# also works (a no-op when `device` is the CPU).\n",
    "model = TwoLayerNet(D_in, H, D_out).to(device)\n",
    "# Summed (not averaged) squared error; `reduction='sum'` replaces the\n",
    "# deprecated `size_average=False` argument.\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "# NOTE(review): the setup cell defines learning_rate = 1e-6, but this loop\n",
    "# uses 1e-4; the literal is kept so training behaviour is unchanged.\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)\n",
    "for t in range(500):\n",
    "    # Forward pass and loss over all of x / y.\n",
    "    y_pred = model(x)\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Clear stale gradients, backpropagate, then take one SGD step.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
